import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

def initilize(layer):
    """Initialise the parameters of a single layer in place.

    Intended to be passed to ``nn.Module.apply`` so that it is called once
    for every sub-module of a network.

    * ``nn.Conv2d``: Kaiming-normal weights (fan-in, ReLU gain); the bias is
      left untouched.
    * ``nn.Linear``: Kaiming-normal weights (fan-in, ReLU gain); the bias, if
      present, is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
    * ``nn.ConvTranspose2d``: Kaiming-normal weights (fan-in, ReLU gain); the
      bias, if present, is set to zero.

    Any other layer type is left unchanged.

    Args:
        layer: The module whose parameters should be initialised.
    """
    if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
    elif isinstance(layer, nn.Linear):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
            # Guard against a zero fan-in so we never divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            # ``(-bound, bound)`` is a symmetric range, so the bias must be
            # sampled uniformly. ``normal_`` would interpret the same
            # arguments as ``mean=-bound, std=bound`` and bias every unit
            # towards negative values.
            nn.init.uniform_(layer.bias, -bound, bound)
    elif isinstance(layer, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)

class DCG(nn.Module):
    """Label-conditioned generator: an MLP followed by transposed convolutions.

    The noise vector is concatenated with one (or two, when ``double_label``
    is set) 16-dimensional label embeddings, pushed through a three-layer MLP,
    reshaped to an ``(embedding_channel, 8, 8)`` feature map and upsampled by
    two stride-2 transposed convolutions, giving a 32x32 output with
    ``num_channels`` channels. A final ``Tanh`` is added when
    ``cfg.Generator.tanh`` is truthy.

    Args:
        cfg: Configuration object; the fields ``num_latent_feature``,
            ``num_class_embedding``, ``double_label``, ``embedding_channel``
            and ``tanh`` are read from ``cfg.Generator``.
        **kwargs: Must contain ``num_classes`` (size of the label vocabulary)
            and ``num_channels`` (channels of the generated image).
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        gen_cfg = cfg.Generator
        self.num_latent_feature = gen_cfg.num_latent_feature
        self.num_classes = kwargs['num_classes']
        self.num_channels = kwargs['num_channels']
        self.num_class_embedding = gen_cfg.num_class_embedding
        self.double_label = gen_cfg.double_label
        self.embedding_channel = gen_cfg.embedding_channel
        self.tanh = gen_cfg.tanh

        # Width of each projected label embedding fed to the MLP.
        label_width = 16

        self.label_embedding = nn.Sequential(
            nn.Embedding(self.num_classes, self.num_class_embedding),
            nn.Linear(self.num_class_embedding, label_width),
        )
        if self.double_label:
            self.second_label_embedding = nn.Sequential(
                nn.Embedding(self.num_classes, self.num_class_embedding),
                nn.Linear(self.num_class_embedding, label_width),
            )

        # The MLP input is the noise plus one embedding per label column.
        num_label_inputs = 2 if self.double_label else 1
        start_dim = self.num_latent_feature + label_width * num_label_inputs

        self.latent = nn.Sequential(
            nn.Linear(start_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.embedding_channel * 8 * 8),
        )

        # Each transposed convolution (kernel 4, stride 2, padding 1) doubles
        # the spatial resolution: 8 -> 16 -> 32.
        upsampling = [
            nn.ConvTranspose2d(
                self.embedding_channel, 32, kernel_size=4,
                stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                32, self.num_channels, kernel_size=4,
                stride=2, padding=1, bias=False),
        ]
        if self.tanh:
            upsampling.append(nn.Tanh())
        self.convs = nn.Sequential(*upsampling)

        # Re-initialise every sub-network with the shared scheme.
        to_initialise = [self.convs, self.latent, self.label_embedding]
        if self.double_label:
            to_initialise.append(self.second_label_embedding)
        for sub_network in to_initialise:
            sub_network.apply(initilize)

    def forward(self, noise, label):
        """Generate a batch of images from noise and class labels.

        Args:
            noise: Tensor concatenated with the label features along
                dimension 1, i.e. of shape ``(batch, num_latent_feature)``.
            label: Integer class indices. With ``double_label`` the columns
                ``label[:, 0]`` and ``label[:, 1]`` are embedded separately;
                otherwise ``label`` is embedded as a whole.

        Returns:
            The output of the transposed-convolution stack.
        """
        if self.double_label:
            label_features = [
                self.label_embedding(label[:, 0]),
                self.second_label_embedding(label[:, 1]),
            ]
        else:
            label_features = [self.label_embedding(label)]

        latent_input = torch.cat([noise, *label_features], dim=1)

        feature_map = self.latent(latent_input)
        feature_map = feature_map.view(-1, self.embedding_channel, 8, 8)

        return self.convs(feature_map)

class LinearG(nn.Module):
    """Label-conditioned fully connected generator.

    The noise vector is concatenated with one (or two, when ``double_label``
    is set) 16-dimensional label embeddings and passed through a stack of
    ``Linear -> BatchNorm1d -> ReLU`` layers whose widths are given by
    ``cfg.Generator.hidden_dims``. A final ``Linear`` layer maps to
    ``prod(input_shape)`` values, optionally followed by ``Tanh``, and the
    result is reshaped to ``(batch, *input_shape)``.

    Args:
        cfg: Configuration object; the fields ``num_latent_feature``,
            ``hidden_dims``, ``double_label``, ``tanh`` and
            ``num_class_embedding`` are read from ``cfg.Generator``.
        **kwargs: Must contain ``input_shape`` (shape of one generated sample,
            without the batch dimension) and ``num_classes`` (size of the
            label vocabulary).
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()

        self.num_latent_feature = cfg.Generator.num_latent_feature
        self.hidden_dims = cfg.Generator.hidden_dims
        self.output_shape = kwargs["input_shape"]
        # Store a plain Python int rather than a NumPy scalar so the value can
        # be used directly as a layer size.
        self.output_dim = int(np.prod(kwargs["input_shape"]))
        self.double_label = cfg.Generator.double_label
        self.tanh = cfg.Generator.tanh

        self.num_classes = kwargs['num_classes']
        self.num_class_embedding = cfg.Generator.num_class_embedding

        self.label_embedding = nn.Sequential(
            nn.Embedding(self.num_classes, self.num_class_embedding),
            nn.Linear(self.num_class_embedding, 16))

        if self.double_label:
            self.second_label_embedding = nn.Sequential(
                nn.Embedding(self.num_classes, self.num_class_embedding),
                nn.Linear(self.num_class_embedding, 16))

        # The MLP input is the noise plus one 16-d embedding per label column.
        start_dim = self.num_latent_feature
        if self.double_label:
            start_dim = start_dim + 16 * 2
        else:
            start_dim = start_dim + 16

        # Widths of every activation from the input up to the last hidden
        # layer. Unpacking accepts any sequence of hidden sizes (list, tuple,
        # config list) and also copes with an empty one, in which case the
        # output layer is attached directly to the input.
        layer_dims = [start_dim, *self.hidden_dims]

        self.model = nn.Sequential()
        for input_dim, output_dim in zip(layer_dims[:-1], layer_dims[1:]):
            self.model.append(nn.Linear(input_dim, output_dim))
            self.model.append(nn.BatchNorm1d(output_dim))
            self.model.append(nn.ReLU())

        self.model.append(nn.Linear(layer_dims[-1], self.output_dim))
        if self.tanh:
            self.model.append(nn.Tanh())

        # NOTE: this class defines neither ``num_guiders`` nor
        # ``guiders_embedding``; reading them here raised AttributeError on
        # every construction, so only the modules that exist are initialised.
        self.model.apply(initilize)
        self.label_embedding.apply(initilize)
        if self.double_label:
            self.second_label_embedding.apply(initilize)

    def forward(self, noise, label):
        """Generate a batch of samples from noise and class labels.

        Args:
            noise: Tensor concatenated with the label features along
                dimension 1, i.e. of shape ``(batch, num_latent_feature)``.
            label: Integer class indices. With ``double_label`` the columns
                ``label[:, 0]`` and ``label[:, 1]`` are embedded separately;
                otherwise ``label`` is embedded as a whole.

        Returns:
            Tensor of shape ``(batch, *input_shape)``.
        """
        if self.double_label:
            label_output = self.label_embedding(label[:, 0])
            second_label_output = self.second_label_embedding(label[:, 1])
            concat = torch.cat((label_output, second_label_output), dim=1)
        else:
            label_output = self.label_embedding(label)
            concat = label_output

        concat = torch.cat((noise, concat), dim=1)

        output = self.model(concat)
        output = output.view(-1, *self.output_shape)
        return output

class IdentityG(nn.Module):
    """Generator that ignores its inputs and returns a learnable tensor.

    A single parameter of shape ``(batch_size, *input_shape)`` is created and
    initialised from a normal distribution with mean 0 and standard deviation
    0.1; ``forward`` returns that parameter as is.

    Args:
        cfg: Configuration object; ``cfg.Generator.batch_size`` gives the
            leading dimension of the parameter.
        **kwargs: Must contain ``input_shape``, the shape of one sample
            without the batch dimension.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()

        self.batch_size = cfg.Generator.batch_size
        self.noise_size = kwargs["input_shape"]

        self.noise = nn.Parameter(torch.zeros((self.batch_size, *self.noise_size)))
        torch.nn.init.normal_(self.noise, mean=0, std=0.1)

    def forward(self, noise, label, guider_index=None):
        """Return the learnable tensor; every argument is ignored.

        ``guider_index`` is optional so this module can also be called as
        ``forward(noise, label)``, like the other generators in this file.

        Args:
            noise: Unused.
            label: Unused.
            guider_index: Unused.

        Returns:
            The ``noise`` parameter of shape ``(batch_size, *input_shape)``.
        """
        return self.noise